{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 搭建多层的 LSTM 网络实现 MNIST 分类\n",
    "\n",
    "通过本例，你可以了解到单层 LSTM 的实现，多层 LSTM 的实现。输入输出数据的格式。\n",
    "\n",
    "如果你已经熟悉 MNIST 数据集的话，应该知道每张手写数字都是一个 28 * 28 的图片，在 RNN 中，我们每个时间步输入一行，一共有 28 个时间步。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /usr/local/lib/python3.5/dist-packages/tensorflow/contrib/learn/python/learn/datasets/base.py:198: retry (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use the retry module or similar alternatives.\n",
      "WARNING:tensorflow:From <ipython-input-1-fbc070ee2f3a>:15: read_data_sets (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use alternatives such as official/mnist/dataset.py from tensorflow/models.\n",
      "WARNING:tensorflow:From /usr/local/lib/python3.5/dist-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:260: maybe_download (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please write your own downloading logic.\n",
      "WARNING:tensorflow:From /usr/local/lib/python3.5/dist-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:262: extract_images (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use tf.data to implement this functionality.\n",
      "Extracting ../data/MNIST_data/train-images-idx3-ubyte.gz\n",
      "WARNING:tensorflow:From /usr/local/lib/python3.5/dist-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:267: extract_labels (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use tf.data to implement this functionality.\n",
      "Extracting ../data/MNIST_data/train-labels-idx1-ubyte.gz\n",
      "WARNING:tensorflow:From /usr/local/lib/python3.5/dist-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:110: dense_to_one_hot (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use tf.one_hot on tensors.\n",
      "Extracting ../data/MNIST_data/t10k-images-idx3-ubyte.gz\n",
      "Extracting ../data/MNIST_data/t10k-labels-idx1-ubyte.gz\n",
      "WARNING:tensorflow:From /usr/local/lib/python3.5/dist-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:290: DataSet.__init__ (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use alternatives such as official/mnist/dataset.py from tensorflow/models.\n",
      "(10000, 10)\n",
      "(55000, 10)\n"
     ]
    }
   ],
   "source": [
    "import warnings\n",
    "warnings.filterwarnings('ignore')  # 不打印 warning \n",
    "\n",
    "import tensorflow as tf\n",
    "\n",
    "# 设置GPU按需增长\n",
    "config = tf.ConfigProto()\n",
    "config.gpu_options.allow_growth = True\n",
    "sess = tf.Session(config=config)\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "# 用tensorflow 导入数据\n",
    "from tensorflow.examples.tutorials.mnist import input_data\n",
    "mnist = input_data.read_data_sets('../data/MNIST_data', one_hot=True) \n",
    "\n",
    "# 看看咱们样本的数量\n",
    "print(mnist.test.labels.shape)\n",
    "print(mnist.train.labels.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    " ** 一、首先设置好模型用到的各个超参数 **"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "lr = 1e-3\n",
    "input_size = 28      # 每个时刻的输入特征是28维的，就是每个时刻输入一行，一行有 28 个像素\n",
    "timestep_size = 28   # 时序持续长度为28，即每做一次预测，需要先输入28行\n",
    "hidden_size = 256    # 隐含层的数量\n",
    "layer_num = 2        # LSTM layer 的层数\n",
    "class_num = 10       # 最后输出分类类别数量，如果是回归预测的话应该是 1\n",
    "cell_type = \"lstm\"   # lstm 或者 block_lstm\n",
    "\n",
    "X_input = tf.placeholder(tf.float32, [None, 784])\n",
    "y_input = tf.placeholder(tf.float32, [None, class_num])\n",
    "# 在训练和测试的时候，我们想用不同的 batch_size.所以采用占位符的方式\n",
    "batch_size = tf.placeholder(tf.int32, [])  # 注意类型必须为 tf.int32, batch_size = 128\n",
    "keep_prob = tf.placeholder(tf.float32, [])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    " ** 二、开始搭建 LSTM 模型，其实普通 RNNs 模型也一样 **"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 把784个点的字符信息还原成 28 * 28 的图片\n",
    "# 下面几个步骤是实现 RNN / LSTM 的关键\n",
    "\n",
    "# **步骤1：RNN 的输入shape = (batch_size, timestep_size, input_size) \n",
    "X = tf.reshape(X_input, [-1, 28, 28])\n",
    "\n",
    "# ** 步骤2：创建 lstm 结构\n",
    "def lstm_cell(cell_type, num_nodes, keep_prob):\n",
    "    assert(cell_type in [\"lstm\", \"block_lstm\"], \"Wrong cell type.\")\n",
    "    if cell_type == \"lstm\":\n",
    "        cell = tf.contrib.rnn.BasicLSTMCell(num_nodes)\n",
    "    else:\n",
    "        cell = tf.contrib.rnn.LSTMBlockCell(num_nodes)\n",
    "    cell = tf.contrib.rnn.DropoutWrapper(cell, output_keep_prob=keep_prob)\n",
    "    return cell\n",
    "\n",
    "mlstm_cell = tf.contrib.rnn.MultiRNNCell([lstm_cell(cell_type, hidden_size, keep_prob) for _ in range(layer_num)], state_is_tuple = True)\n",
    "\n",
    "# **步骤3：用全零来初始化state\n",
    "init_state = mlstm_cell.zero_state(batch_size, dtype=tf.float32)\n",
    "\n",
    "# **步骤4：方法一，调用 dynamic_rnn() 来让我们构建好的网络运行起来\n",
    "# ** 当 time_major==False 时， outputs.shape = [batch_size, timestep_size, hidden_size] \n",
    "# ** 所以，可以取 h_state = outputs[:, -1, :] 作为最后输出\n",
    "# ** state.shape = [layer_num, 2, batch_size, hidden_size], \n",
    "# ** 或者，可以取 h_state = state[-1][1] 作为最后输出\n",
    "# ** 最后输出维度是 [batch_size, hidden_size]\n",
    "# outputs, state = tf.nn.dynamic_rnn(mlstm_cell, inputs=X, initial_state=init_state, time_major=False)\n",
    "# h_state = state[-1][1]\n",
    "\n",
    "# # *************** 为了更好的理解 LSTM 工作原理，我们把上面 步骤6 中的函数自己来实现 ***************\n",
    "# # 通过查看文档你会发现， RNNCell 都提供了一个 __call__()函数，我们可以用它来展开实现LSTM按时间步迭代。\n",
    "# # **步骤4：方法二，按时间步展开计算\n",
    "outputs = list()\n",
    "state = init_state\n",
    "with tf.variable_scope('RNN'):\n",
    "    for timestep in range(timestep_size):\n",
    "        (cell_output, state) = mlstm_cell(X[:, timestep, :],state)\n",
    "        outputs.append(cell_output)\n",
    "h_state = outputs[-1]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    " ** 三、最后设置 loss function 和 优化器，展开训练并完成测试 **"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "step 500, train cost=0.016716, acc=0.930000; test cost=0.012351, acc=0.960500; pass 17.112618923187256s\n",
      "step 1000, train cost=0.011734, acc=0.960000; test cost=0.008155, acc=0.975400; pass 13.219411611557007s\n",
      "step 1500, train cost=0.007815, acc=0.990000; test cost=0.007215, acc=0.978000; pass 13.311408996582031s\n",
      "step 2000, train cost=0.010880, acc=0.980000; test cost=0.007247, acc=0.978800; pass 13.056209564208984s\n",
      "step 2500, train cost=0.006850, acc=0.990000; test cost=0.004864, acc=0.985400; pass 13.145941257476807s\n",
      "step 3000, train cost=0.006143, acc=0.980000; test cost=0.005402, acc=0.984600; pass 13.041409969329834s\n",
      "step 3500, train cost=0.000579, acc=1.000000; test cost=0.004992, acc=0.985700; pass 13.318841695785522s\n",
      "step 4000, train cost=0.005876, acc=0.980000; test cost=0.004882, acc=0.986000; pass 13.208428621292114s\n",
      "step 4500, train cost=0.004937, acc=0.980000; test cost=0.004194, acc=0.988400; pass 13.295379400253296s\n",
      "step 5000, train cost=0.005670, acc=0.980000; test cost=0.003884, acc=0.989100; pass 12.929810285568237s\n"
     ]
    }
   ],
   "source": [
    "############################################################################\n",
    "# 以下部分其实和之前写的多层 CNNs 来实现 MNIST 分类是一样的。\n",
    "# 只是在测试的时候也要设置一样的 batch_size.\n",
    "\n",
    "# 上面 LSTM 部分的输出会是一个 [hidden_size] 的tensor，我们要分类的话，还需要接一个 softmax 层\n",
    "# 首先定义 softmax 的连接权重矩阵和偏置\n",
    "\n",
    "import time\n",
    "\n",
    "# 开始训练和测试\n",
    "W = tf.Variable(tf.truncated_normal([hidden_size, class_num], stddev=0.1), dtype=tf.float32)\n",
    "bias = tf.Variable(tf.constant(0.1,shape=[class_num]), dtype=tf.float32)\n",
    "y_pre = tf.nn.softmax(tf.matmul(h_state, W) + bias)\n",
    "\n",
    "\n",
    "# 损失和评估函数\n",
    "cross_entropy = -tf.reduce_mean(y_input * tf.log(y_pre))\n",
    "train_op = tf.train.AdamOptimizer(lr).minimize(cross_entropy)\n",
    "\n",
    "correct_prediction = tf.equal(tf.argmax(y_pre,1), tf.argmax(y_input,1))\n",
    "accuracy = tf.reduce_mean(tf.cast(correct_prediction, \"float\"))\n",
    "\n",
    "\n",
    "sess.run(tf.global_variables_initializer())\n",
    "time0 = time.time()\n",
    "for i in range(5000):\n",
    "    _batch_size=100\n",
    "    X_batch, y_batch = mnist.train.next_batch(batch_size=_batch_size)\n",
    "    cost, acc,  _ = sess.run([cross_entropy, accuracy, train_op], feed_dict={X_input: X_batch, y_input: y_batch, keep_prob: 0.5, batch_size: _batch_size})\n",
    "    if (i+1) % 500 == 0:\n",
    "        # 分 100 个batch 迭代\n",
    "        test_acc = 0.0\n",
    "        test_cost = 0.0\n",
    "        N = 100\n",
    "        for j in range(N):\n",
    "            X_batch, y_batch = mnist.test.next_batch(batch_size=_batch_size)\n",
    "            _cost, _acc = sess.run([cross_entropy, accuracy], feed_dict={X_input: X_batch, y_input: y_batch, keep_prob: 1.0, batch_size: _batch_size})\n",
    "            test_acc += _acc\n",
    "            test_cost += _cost\n",
    "        print(\"step {}, train cost={:.6f}, acc={:.6f}; test cost={:.6f}, acc={:.6f}; pass {}s\".format(i+1, cost, acc, test_cost/N, test_acc/N, time.time() - time0))\n",
    "        time0 = time.time()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "这里我们实现了一个两层的 lstm 网络实现了 MNIST 手写数字的识别，从结果来看，相当不错。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 四、可视化看看 LSTM 的是怎么做分类的"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "毕竟 LSTM 更多的是用来做时序相关的问题，要么是文本，要么是序列预测之类的，所以很难像 CNNs 一样非常直观地看到每一层中特征的变化。在这里，我想通过可视化的方式来帮助大家理解 LSTM 是怎么样一步一步地把图片正确的给分类。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(5, 784) (5, 10)\n",
      "_outputs.shape = (28, 5, 256)\n",
      "arr_state.shape = (2, 2, 5, 256)\n"
     ]
    }
   ],
   "source": [
    "# 手写的结果 shape\n",
    "_batch_size = 5\n",
    "X_batch, y_batch = mnist.test.next_batch(_batch_size)\n",
    "print(X_batch.shape, y_batch.shape)\n",
    "_outputs, _state = np.array(sess.run([outputs, state],feed_dict={\n",
    "            X_input: X_batch, y_input: y_batch, keep_prob: 1.0, batch_size: _batch_size}))\n",
    "print('_outputs.shape =', np.asarray(_outputs).shape)\n",
    "print('arr_state.shape =', np.asarray(_state).shape)\n",
    "# 可见 outputs.shape = [ batch_size, timestep_size, hidden_size]\n",
    "# state.shape = [layer_num, 2, batch_size, hidden_size]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "看下面我找了一个字符"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[0. 0. 0. 0. 0. 0. 0. 0. 1. 0.]\n",
      " [0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]\n",
      " [0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]\n",
      " [0. 0. 0. 0. 0. 0. 0. 0. 1. 0.]\n",
      " [0. 0. 0. 0. 0. 0. 1. 0. 0. 0.]\n",
      " [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]\n",
      " [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]\n",
      " [0. 0. 0. 0. 0. 0. 0. 0. 1. 0.]\n",
      " [0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]\n",
      " [0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]]\n"
     ]
    }
   ],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "\n",
    "print(mnist.train.labels[10:20])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "我们先来看看这个字符样子,上半部分还挺像 2 来的"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAADbxJREFUeJzt3W+MFPUdx/HPV8S/YISihACVYpSE8OBsTmLi2dAoxv+I0Qs8opH0eCCmGmKq1ESTUjWN2jQ+MDkiehiLSoRIiH9LtAhpFCQKeBS15owQ/hQw/gkPQPj2wQ7tqbe/OXZnd/b8vl/J5XbnuzPzZcPnZmZndn7m7gIQzyllNwCgHIQfCIrwA0ERfiAowg8ERfiBoAg/EBThB4Ii/EBQpzZzZWbG5YRAg7m7DeZ1dW35zewaM9tpZp+a2b31LAtAc1mt1/ab2TBJH0uaKWmXpE2S5rp7b2IetvxAgzVjyz9d0qfu/pm7H5H0vKRZdSwPQBPVE/7xkr7o93xXNu17zKzLzDab2eY61gWgYA3/wM/duyV1S+z2A62kni3/bkkT+z2fkE0DMATUE/5Nki4ys1+Y2WmS5khaU0xbABqt5t1+d//OzBZKel3SMEnL3P2jwjoD0FA1n+qraWUc8wMN15SLfAAMXYQfCIrwA0ERfiAowg8ERfiBoAg/EBThB4Ii/EBQhB8IivADQRF+ICjCDwRF+IGgCD8QFOEHgiL8QFCEHwiK8ANBEX4gKMIPBEX4gaAIPxAU4QeCIvxAUIQfCIrwA0ERfiAowg8EVfMQ3ZJkZn2SvpF0TNJ37t5eRFMAGq+u8Gd+7e4HClgOgCZitx8Iqt7wu6Q3zOx9M+sqoiEAzVHvbn+Hu+82s/MlvWlm/3L39f1fkP1R4A8D0GLM3YtZkNmDkr5190cTrylmZQCqcncbzOtq3u03s7PNbOSJx5KulrS91uUBaK56dvvHSlptZieW8zd3f62QrgA0XGG7/YNaGbv9QMM1fLcfwNBG+IGgCD8QFOEHgiL8QFCEHwiqiG/1YQg744wzkvVhw4Y1qZPmOnLkSLJ+9OjRJnVSHrb8QFCEHwiK8ANBEX4gKMIPBEX4gaAIPxAU5/mb4IILLkjWJ0+enKyPHDkyWZ8zZ85J93TCzJkzk/UxY8bUvOxW9uGHHybrvb29yfrBgweT9VWrViXrb731VrLeDGz5gaAIPxAU4QeCIvxAUIQfCIrwA0ERfiAobt2dGT58eLK+YMGCqrW77747Oe+5556brI8ePTpZP3z4cM315cuXJ+fN+157nh07diTrGzdurGv5jXLppZcm63n3Oejo6EjW58+fn6xn4100BLfuBpBE+IGgCD8QFOEHgiL8QFCEHwiK8ANB5Z7nN7Nlkm6QtN/dp2XTRkt6QdIkSX2SOt39y9yVtfB5/qVLlybrt99+e9XaK6+8kpx369atyfqmTZuS9bxz6Tt37kzWUbzrr78+WV+7dm2yPlTO8z8j6ZofTLtX0jp3v0jSuuw5gCEkN/zuvl7SoR9MniWpJ3vcI+nmgvsC0GC1HvOPdfc92eO9ksYW1A+AJqn7Hn7u7qljeTPrktRV73oAFKvWLf8+MxsnSdnv/dVe6O7d7t7u7u01rgtAA9Qa/jWS5mWP50l6uZh2ADRLbvjNbIWkf0qaYma7zGy+pEckzTSzTyRdlT0HMITkHvO7+9wqpSsL7qVUN910U7J+5ZXV/7lvv/12wd2g0UaNGpWs33///cn67Nmzk/XOzs6T7qnZuMIPCIrwA0ERfiAowg8ERfiBoAg/EBRDdGcuvvjiZP3YsWM1L7utrS1Zv/POO5P1xx57LFnPG046qvHjx1etrVu3LjnvmWeemawvWbIkWV+5cmWy3grY8gNBEX4gKMIPBEX4gaAIPxAU4QeCIvxAUAzR3QT33Xdfsv7QQw8l6wcPHkzWe3p6qtbybus9ZcqUZP3AgQPJ+tNPP52s7927N1lPGTFiRLJ+4403JuuLFi2qWsu7NmLFihXJ+quvvpqsl4khugEkEX4gKMIPBEX4gaAIPxAU4QeCIvxAUJznb4IJEyYk61dccUWyfssttyTrt95660n3VJS88/gPP/xw1dqGDRuS8z777LPJ+tSpU5P11LDqedcnDGWc5weQRPiBoAg/EBThB4Ii/EBQhB8IivADQeWe5zezZZJukLTf3adl0x6U9FtJ/8lettjdX8ldWdDz/PUyS5+2PfXUxg2/cPnllyfry5cvT9YnTpxYtXb06NHkvK+99lqynncfhPfee69q7fjx48l5h7Iiz/M/I+maAab/xd3bsp/c4ANoLbnhd/f1kg41oRcATVTPMf9CM9tqZsvMbFRhHQFoilrD/6SkCyW1SdojqepgcmbWZWabzWxzjesC0AA1hd/d97n7MXc/LmmppOmJ13a7e7u7t9faJIDi1RR+MxvX7+lsSduLaQdAs+SeIzKzFZJmSBpjZrskPSBphpm1SXJJfZIWNLBHAA2QG353nzvA5Kca0AuqyLsWI+98ecopp6R3/jo6OpL1MWPGJOup7/s/8MADyXm7u7uTddSHK/yAoAg/EBThB4Ii/EBQhB8IivADQXHr7p+4a6+9NllfvHhxsp53qi9P6vbcebcsR224dTeAJMIPBEX4gaAIPxAU4QeCIvxAUIQfCKpx93xG00ybNq1qbfXq1cl5e3t7k/UXX3wxWe/s7EzWt23blqyjPGz5gaAIPxAU4QeCIvxAUIQfCIrwA0ERfiAovs8/BJx33nnJ+vbt1cdMOeuss5LzXnbZZcn6woULk/W8+wXMmDGjaq2vry85L2rD9/kBJBF+ICjCDwRF+IGgCD8QFOEHgiL8QFC53+c3s4mSlksaK8kldbv7X81stKQXJE2S1Cep092/bFyrcV111VXJ+jnnnFO1dttttyXnHTlyZLKe9339RYsWJeucy29dg9nyfydpkbtPlXSZpDvMbKqkeyWtc/eLJK3LngMYInLD7+573H1L9vgbSTskjZc0S1JP9rIeSTc3qkkAxTupY34zmyTpEknvShrr7nuy0l5VDgsADBGDvoefmY2Q9JKku9z9a7P/Xz7s7l7tun0z65LUVW+jAIo1qC2/mQ1XJfjPufuqbPI+MxuX1cdJ2j/QvO7e7e7t7t5eRMMAipEbfqts4p+StMPdH+9XWiNpXvZ4nqSXi28PQKPkfqXXzDokvSNpm6Tj2eTFqhz3vyjp55I+V+VU36GcZfGV3gFMnz49WV+/fn2y/sQTT1St9fT0VK1J0sqVK5P1vFt/5w3xjeYb7Fd6c4/53X2DpGoLu/JkmgLQOrjCDwiK8ANBEX4gKMIPBEX4gaAIPxAUQ3S3gPPPPz9ZP/3005P1w4cPV61t3LgxOe+WLVuS9SVLliTrGLrY8gNBEX4gKMIPBEX4gaAIPxAU4QeCIvxAUAzR3QLa2tqS9XvuuafmZb/77rvJet73/b/66qua141yMEQ3gCTCDwRF+IGgCD8QFOEHgiL8QFCEHwiK8/zATwzn+QEkEX4gKMIPBEX4gaAIPxAU4QeCIvxAULnhN7OJZvaWmfWa2Udm9rts+oNmttvMPsh+rmt8uwCKknuRj5mNkzTO3beY2UhJ70u6WVKnpG/d/dFBr4yLfICGG+xFPrkj9rj7Hkl7ssffmNkOSePraw9A2U7qmN/MJkm6RNKJe0MtNLOtZrbMzEZVmafLzDab2ea6OgVQqEFf229mIyT9Q9Kf3H2VmY2VdECSS/qjKocGt+csg91+oMEGu9s/qPCb2XBJayW97u6PD1CfJGmtu0/LWQ7hBxqssC/2mJlJekrSjv7Bzz4IPGG2pO0n2ySA8gzm0/4OSe9I2ibpeDZ5saS5ktpU2e3vk7Qg+3AwtSy2/ECDFbrbXxTCDzQe3+cHkET4gaAIPxAU4QeCIvxAUIQfCIrwA0ERfiAowg8ERfiBoAg/EBThB4Ii/EBQhB8IKvcGngU7IOnzfs/HZNNaUav21qp9SfRWqyJ7u2CwL2zq9/l/tHKzze7eXloDCa3aW6v2JdFbrcrqjd1+ICjCDwRVdvi7S15/Sqv21qp9SfRWq1J6K/WYH0B5yt7yAyhJKeE3s2vMbKeZfWpm95bRQzVm1mdm27KRh0sdYiwbBm2/mW3vN220mb1pZp9kvwccJq2k3lpi5ObEyNKlvnetNuJ103f7zWyYpI8lzZS0S9ImSXPdvbepjVRhZn2S2t299HPCZvYrSd9KWn5iNCQz+7OkQ+7+SPaHc5S7/75FentQJzlyc4N6qzay9G9U4ntX5IjXRShjyz9d0qfu/pm7H5H0vKRZJfTR8tx9vaRDP5g8S1JP9rhHlf88TVelt5bg7nvcfUv2+BtJJ0aWLvW9S/RVijLCP17SF/2e71JrDfntkt4ws/fNrKvsZgYwtt/ISHsljS2zmQHkjtzcTD8YWbpl3rtaRrwuGh/4/ViHu/9S0rWS7sh2b1uSV47ZWul0zZOSLlRlGLc9kh4rs5lsZOmXJN3l7l/3r5X53g3QVynvWxnh3y1pYr/nE7JpLcHdd2e/90tarcphSivZd2KQ1Oz3/pL7+R933+fux9z9uKSlKvG9y0aWfknSc+6+Kptc+ns3UF9lvW9lhH+TpIvM7BdmdpqkOZLWlNDHj5jZ2dkHMTKzsyVdrdYbfXiNpHnZ43mSXi6xl+9plZGbq40srZLfu5Yb8drdm/4j6TpVPvH/t6Q/lNFDlb4mS/ow+/mo7N4krVBlN/CoKp+NzJf0M0nrJH0i6e+SRrdQb8+qMprzVlWCNq6k3jpU2aXfKumD7Oe6st+7RF+lvG9c4QcExQd+QFCEHwiK8ANBEX4gKMIPBEX4gaAIPxAU4QeC+i+tAYOFG7TwKgAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "X3 = mnist.train.images[10]\n",
    "img3 = X3.reshape([28, 28])\n",
    "plt.imshow(img3, cmap='gray')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "我们看看在分类的时候，一行一行地输入，分为各个类别的概率会是什么样子的。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(28, 1, 256)\n",
      "(28, 256)\n"
     ]
    }
   ],
   "source": [
    "X3.shape = [-1, 784]\n",
    "y_batch = mnist.train.labels[0]\n",
    "y_batch.shape = [-1, class_num]\n",
    "\n",
    "X3_outputs = np.array(sess.run(outputs, feed_dict={\n",
    "            X_input: X3, y_input: y_batch, keep_prob: 1.0, batch_size: 1}))\n",
    "print(X3_outputs.shape)\n",
    "X3_outputs.shape = [28, hidden_size]\n",
    "print(X3_outputs.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAX0AAAD8CAYAAACb4nSYAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAABkZJREFUeJzt3UFyozgAQNHxnKSP2Rfr8zGbmeqqTsZOABmJ/942lQQL8REOOI9t2/4CoOHvqzcAgPcRfYAQ0QcIEX2AENEHCBF9gBDRBwgRfYAQ0QcIEX2AENEHCFk6+j9+/rrFBwf9+Plrm+21PNumGbf3ToztOObu4tG/Qn3CXMXBOpax7Zg++nsn4xUr1Wc/c+/vHBm7SkgLr5F5zTb/po8+AOcR/U/MdmYGOIvoc7ojb8nt/b7ZTtSvtmm27aVD9EkTX2pEn7cSWbiW6AOEiD6w21VXbq4Y9xN94FacEJ4TfVjIiAcAj/zOkruMg+hDwBW3tc54Ky2iD5Ai+gAXueJqSPQBJjTqqW7RB/ir8zcI0QcIEX2Ag1a6QhB9gBDRBwgRfYAQ0QcIEX2AENEHCHls2zJ3GgFwkJU+QIjoA4SIPkCI6AOEiD5AiOifaKUPXbqKMRqn8tHAd+KfqJzIAQDw0W2jD3CmuywiRR8gRPQBQkQfIET0P3GX9+4A/iT6ACGizxLcggvnEH3gKSfbexF9vs2qG9Yl+gAh00ffipJXXHnA100f/bsQJuBMe3si+qQ5EVOzdPQLB6wrBFZl7s5p6egD8D2iH2Ylxn/Mgw7Rn4D4Au+SjL7IvmaMrmPcGSkZ/SMckOuxzxhtpTkm+pxu7wGw0oHzyqsrpTu9Vp6bbV+LPgSMervuqqDtPaGu9LblqG0VfVjIKsFa0UonhCNEHyBE9AFCHtt2+6sZAP5lpQ8QIvoAIaIPECL6ACGiDxAi+m9SefDjCOMzjvk31kpjK/oAIaIPECL6ACGiDxAi+gAhog8QIvoAIaIPECL6ACGiPzlPUgJnEn2AENH/hJU1cFeiz/K8BQZfJ/oAIaIPECL6nG6lt1pW2lY4g+h/k0gAKxN9gBDRBwgRfYAQ0YeF+JsSRyWjP9vDPLNtz0iV18l1zLHnlo6+nctXmCd8RWXxNX30V9oJK23rK5UDgBbzeoHo72XnvrZ3jK4Y29X256vt3fu1I7/z1feO+NqofXbFXBg1tv/39SNz6JnHti1zHAFw0G1X+gB8JPoAIaIPECL6ACGiDxCydPRXuoXvTla7fXI1xnYcc3fx6D9j5wJ8dNvoA/CR6AOEiD5AiOgDhIg+QIjoA4SIPkCI6AOEiD5AiOgDhIg+QIjof5PP8wFWJvoAIaIPECL6ACGiDxAi+gAhog8QIvoAIaIPECL6fMo/lod7En2AgWZbPE0f/REDNmoVO9vOBfjTY9t0CqBi+pU+AOcRfYAQ0QcIEX2AENEHCBF9gJDbRt8Tpa8Zo3GM7VjGdr/bRh+Aj0QfIET0AUJEHyBE9AFCRP9E7igAZif6ACGiDxAi+izBw05wDtEHCBF9gBDRBwiZPvrexwU4z/TRB+A8og8QIvoAIaLPNPz95jrGvkP0AUJEHyBE9L/JZbCPRFiRfXadUeO+9+eKPsBBK51QRR/gIldcgSWj/2qgR+yE0uX1u1/nkbGt7JMrlOb8SpLRn82zg+NuB84Vr2W18du7IBk1V1Ybv2eOHGsrLWaeeWzbbfYnAC9Y6QOEiD5AiOgDhIg+QIjoA4SIPkDI9NG/0z3CzMf8Guduz5jMxmfvAPCS6E/Oagk4k+gDhIg+QIjoA4SIPkCI6AOEiD5AiOgDhIg+QIjoszwPsI1lbMfxj9FvTJiAGYg+QIjoA4SIPkCI6AOEiD7AhEbd/CH6ACGiDxAi+t/07HJrtfvwPTsAPaIP7GbRsB7R/8S7J7IV91jGFn67bfSFdE53ensMVnTb6D/jhHDMq/EbMbar7bMrtvfZ71xt/GZ0l/F7bNstXgcAX5Bc6QNUiT5AiOgDhIg+QIjoA4SIPkCI6E/APdRjGdtxjO04PloZgMNEHyBE9CfnrR/gTKIPECL6ACGiDxAi+gAhog8QIvoAIaIPECL6ACGi/yYesgJmIPoAIaIPECL6ACGiDxAi+gAhog8QIvoAIaIPECL6pHlgjhrRBwgRfZbnIy7g65LRfxUJAYHf9h4PTsZzemybfQJQkVzpA1SJPkCI6AOEiD5AiOgDhIg+QIjoA4SIPkCI6AOEiD5AiOgDhIg+QIjoA4SIPkCI6AOEiD5AiOgDhIg+QIjoA4SIPkCI6AOEiD5AiOgDhPwDjHPUNwE/Bu0AAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x288 with 28 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "h_W, h_bias = sess.run([W, bias], feed_dict={\n",
    "            X_input:X3, y_input: y_batch, keep_prob: 1.0, batch_size: 1})\n",
    "h_bias = h_bias.reshape([-1, 10])\n",
    "\n",
    "bar_index = range(class_num)\n",
    "for i in range(X3_outputs.shape[0]):\n",
    "    plt.subplot(7, 4, i+1)\n",
    "    X3_h_shate = X3_outputs[i, :].reshape([-1, hidden_size])\n",
    "    pro = sess.run(tf.nn.softmax(tf.matmul(X3_h_shate, h_W) + h_bias))\n",
    "    plt.bar(bar_index, pro[0], width=0.2 , align='center')\n",
    "    plt.axis('off')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "在上面的图中，为了更清楚地看到线条的变化，我把坐标都去了，每一行显示了 4 个图，共有 7 行，表示了一行一行读取过程中，模型对字符的识别。可以看到，在只看到前面的几行像素时，模型根本认不出来是什么字符，随着看到的像素越来越多，最后就基本确定了这个字符的正确分类."
   ]
  }
 ],
 "metadata": {
  "anaconda-cloud": {},
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
